{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "a5d6ad19-adff-423b-8177-30a0d4f6ceed",
   "metadata": {},
   "outputs": [],
   "source": [
    "# import accelerate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "02e86302-68c0-455a-9457-6a4618b95cba",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: HF_ENDPOINT=https://hf-mirror.com\n"
     ]
    }
   ],
   "source": [
    "%env HF_ENDPOINT=https://hf-mirror.com\n",
    "import os\n",
    "os.environ['HF_HOME'] = '/data1/ckw'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "474af3e5-820c-4d8f-9ee0-e320edc3ccee",
   "metadata": {},
   "outputs": [],
   "source": [
    "os.environ['HF_ENDPOINT']='https://hf-mirror.com'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "e46a03d4-186e-46ee-b9e2-e2826c06e3e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/data1/ckw/01大语言模型/ChatGLM4\n"
     ]
    }
   ],
   "source": [
    "!pwd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "62284871-67f4-47a2-941f-46befd2032b7",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    }
   ],
   "source": [
    "import torch\n",
    "# from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "from tokenization_chatglm import ChatGLM4Tokenizer\n",
    "from modeling_chatglm import ChatGLMForConditionalGeneration\n",
    "\n",
    "device = \"cuda\"\n",
    "\n",
    "tokenizer = ChatGLM4Tokenizer.from_pretrained(\"THUDM/glm-4-9b-chat\", trust_remote_code=True)\n",
    "\n",
    "query = \"你好\"\n",
    "\n",
    "inputs = tokenizer.apply_chat_template([{\"role\": \"user\", \"content\": query}],\n",
    "                                       add_generation_prompt=True,\n",
    "                                       tokenize=True,\n",
    "                                       return_tensors=\"pt\",\n",
    "                                       return_dict=True\n",
    "                                       )\n",
    "\n",
    "inputs = inputs.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "5315c9b3-7fc4-4291-a9b1-6bcd92b4d940",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The argument `trust_remote_code` is to be used with Auto classes. It has no effect here and is ignored.\n",
      "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e68152ca9a1a462d8b96e1e31e7c582b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/10 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n",
      "你好👋！有什么可以帮助你的吗？\n"
     ]
    }
   ],
   "source": [
    "model = ChatGLMForConditionalGeneration.from_pretrained(\n",
    "    \"THUDM/glm-4-9b-chat\",\n",
    "    torch_dtype=torch.bfloat16,\n",
    "    low_cpu_mem_usage=True,\n",
    "    trust_remote_code=True,\n",
    "    load_in_4bit=True,\n",
    "    device_map='auto'\n",
    ").eval()#.to(device)\n",
    "\n",
    "gen_kwargs = {\"max_length\": 2500, \"do_sample\": True, \"top_k\": 1}\n",
    "with torch.no_grad():\n",
    "    outputs = model.generate(**inputs, **gen_kwargs)\n",
    "    outputs = outputs[:, inputs['input_ids'].shape[1]:]\n",
    "    print(tokenizer.decode(outputs[0], skip_special_tokens=True))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "9e6f3cf8-2032-4b0a-be34-6e7f4409a01a",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers.utils import is_accelerate_available"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "1466e0e6-7d99-4962-82c9-4fff85b7b3d4",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "False"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "is_accelerate_available()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "5033e32b-9daf-4b8a-a239-da883414936a",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\u001b[0;31mSignature:\u001b[0m \u001b[0mis_accelerate_available\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmin_version\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m'0.21.0'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
       "\u001b[0;31mDocstring:\u001b[0m <no docstring>\n",
       "\u001b[0;31mSource:\u001b[0m   \n",
       "\u001b[0;32mdef\u001b[0m \u001b[0mis_accelerate_available\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmin_version\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstr\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mACCELERATE_MIN_VERSION\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n",
       "\u001b[0;34m\u001b[0m    \u001b[0;32mreturn\u001b[0m \u001b[0m_accelerate_available\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mversion\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_accelerate_version\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mversion\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mparse\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmin_version\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
       "\u001b[0;31mFile:\u001b[0m      /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages/transformers/utils/import_utils.py\n",
       "\u001b[0;31mType:\u001b[0m      function"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "??is_accelerate_available"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "9fef3754-bbe9-4783-877f-b747421ada1e",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple\n",
      "Requirement already satisfied: accelerate in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (0.20.3)\n",
      "Collecting accelerate\n",
      "  Downloading https://pypi.tuna.tsinghua.edu.cn/packages/f0/62/9ebaf1fdd3d3c737a8814f9ae409d4ac04bc93b26a46a7dab456bb7e16f8/accelerate-0.31.0-py3-none-any.whl (309 kB)\n",
      "\u001b[2K     \u001b[38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m309.4/309.4 kB\u001b[0m \u001b[31m2.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m[31m1.8 MB/s\u001b[0m eta \u001b[36m0:00:01\u001b[0m\n",
      "\u001b[?25hRequirement already satisfied: numpy>=1.17 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (1.24.3)\n",
      "Requirement already satisfied: packaging>=20.0 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (23.2)\n",
      "Requirement already satisfied: psutil in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (5.9.5)\n",
      "Requirement already satisfied: pyyaml in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (6.0.1)\n",
      "Requirement already satisfied: torch>=1.10.0 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (2.1.0.post301)\n",
      "Requirement already satisfied: huggingface-hub in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (0.23.2)\n",
      "Requirement already satisfied: safetensors>=0.3.1 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from accelerate) (0.4.3)\n",
      "Requirement already satisfied: filelock in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from torch>=1.10.0->accelerate) (3.12.4)\n",
      "Requirement already satisfied: typing-extensions in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from torch>=1.10.0->accelerate) (4.9.0)\n",
      "Requirement already satisfied: sympy in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from torch>=1.10.0->accelerate) (1.12)\n",
      "Requirement already satisfied: networkx in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from torch>=1.10.0->accelerate) (3.2.1)\n",
      "Requirement already satisfied: jinja2 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from torch>=1.10.0->accelerate) (3.1.2)\n",
      "Requirement already satisfied: fsspec in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from torch>=1.10.0->accelerate) (2023.12.2)\n",
      "Requirement already satisfied: requests in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
      "Requirement already satisfied: tqdm>=4.42.1 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from huggingface-hub->accelerate) (4.65.0)\n",
      "Requirement already satisfied: MarkupSafe>=2.0 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from jinja2->torch>=1.10.0->accelerate) (2.1.3)\n",
      "Requirement already satisfied: charset-normalizer<4,>=2 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from requests->huggingface-hub->accelerate) (3.3.0)\n",
      "Requirement already satisfied: idna<4,>=2.5 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n",
      "Requirement already satisfied: urllib3<3,>=1.21.1 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from requests->huggingface-hub->accelerate) (1.26.18)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from requests->huggingface-hub->accelerate) (2023.7.22)\n",
      "Requirement already satisfied: mpmath>=0.19 in /data1/ckw/micromamba/envs/kewei-ai/lib/python3.11/site-packages (from sympy->torch>=1.10.0->accelerate) (1.3.0)\n",
      "Installing collected packages: accelerate\n",
      "  Attempting uninstall: accelerate\n",
      "    Found existing installation: accelerate 0.20.3\n",
      "    Uninstalling accelerate-0.20.3:\n",
      "      Successfully uninstalled accelerate-0.20.3\n",
      "Successfully installed accelerate-0.31.0\n",
      "Note: you may need to restart the kernel to use updated packages.\n"
     ]
    }
   ],
   "source": [
    "%pip install --upgrade accelerate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "38eb42ba-0628-4f9d-bf29-410607552950",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "kewei-ai",
   "language": "python",
   "name": "kewei-ai"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
